{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🔥 Diffusion Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own diffusion model on the Oxford flowers dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bacc8be-c1d2-4108-aa7c-7453dbc777a6",
   "metadata": {},
   "source": [
    "The code is adapted from the excellent ['Denoising Diffusion Implicit Models' tutorial](https://keras.io/examples/generative/ddim/) created by András Béres available on the Keras website."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use(\"seaborn-v0_8-colorblind\")\n",
    "\n",
    "import math\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import (\n",
    "    layers,\n",
    "    models,\n",
    "    optimizers,\n",
    "    utils,\n",
    "    callbacks,\n",
    "    metrics,\n",
    "    losses,\n",
    "    activations,\n",
    ")\n",
    "\n",
    "from notebooks.utils import display, sample_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side length in pixels that images are resized to; also the U-Net input size\n",
    "IMAGE_SIZE = 64\n",
    "# Number of images per training batch\n",
    "BATCH_SIZE = 64\n",
    "# How many times the training dataset is repeated\n",
    "DATASET_REPETITIONS = 5\n",
    "# When True, weights are loaded from ./checkpoint/checkpoint.ckpt before training\n",
    "LOAD_MODEL = False\n",
    "\n",
    "# Channel count of the sinusoidal noise embedding (half sine, half cosine)\n",
    "NOISE_EMBEDDING_SIZE = 32\n",
    "# Reverse diffusion steps used for the images generated at the end of each epoch\n",
    "PLOT_DIFFUSION_STEPS = 20\n",
    "\n",
    "# optimization\n",
    "# Decay factor of the exponential moving average kept over the network weights\n",
    "EMA = 0.999\n",
    "LEARNING_RATE = 1e-3\n",
    "WEIGHT_DECAY = 1e-4\n",
    "EPOCHS = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d4f5e63-e36a-4dc8-9f03-cb29c1fa5290",
   "metadata": {},
   "source": [
    "## 1. Prepare the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d089c317-2d66-4631-81f8-1a8230635845",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "# labels=None: no labels are requested, only images.\n",
    "# batch_size=None: batching is left to the preprocessing step that follows.\n",
    "# Images are resized to IMAGE_SIZE x IMAGE_SIZE with bilinear interpolation.\n",
    "train_data = utils.image_dataset_from_directory(\n",
    "    \"/app/data/pytorch-challange-flower-dataset/dataset\",\n",
    "    labels=None,\n",
    "    image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
    "    batch_size=None,\n",
    "    shuffle=True,\n",
    "    seed=42,\n",
    "    interpolation=\"bilinear\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20697102-8c8d-4582-88d4-f8e2af84e060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the data\n",
    "def preprocess(img):\n",
    "    \"\"\"Cast an image tensor to float32 and divide it by 255.\"\"\"\n",
    "    return tf.cast(img, \"float32\") / 255.0\n",
    "\n",
    "\n",
    "# Scale every image, repeat the dataset, then group it into full batches only\n",
    "train = (\n",
    "    train_data.map(preprocess)\n",
    "    .repeat(DATASET_REPETITIONS)\n",
    "    .batch(BATCH_SIZE, drop_remainder=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e1a420-699e-4869-8d10-3c049dbad030",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show some flower images from the training set\n",
    "train_sample = sample_batch(train)\n",
    "display(train_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f53945d9-b7c5-49d0-a356-bcf1d1e1798b",
   "metadata": {},
   "source": [
    "### 1.1 Diffusion schedules <a name=\"diffusion_schedules\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "330083c0-642e-4745-b160-1938493152da",
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_diffusion_schedule(diffusion_times):\n",
    "    \"\"\"Linear beta schedule.\n",
    "\n",
    "    Betas are interpolated linearly between 0.0001 and 0.02 using\n",
    "    diffusion_times as the interpolation weight. alpha_bar is the cumulative\n",
    "    product of (1 - beta), taken along tf.math.cumprod's default axis.\n",
    "\n",
    "    Returns:\n",
    "        (noise_rates, signal_rates): sqrt(1 - alpha_bar) and sqrt(alpha_bar).\n",
    "    \"\"\"\n",
    "    beta_start, beta_end = 0.0001, 0.02\n",
    "    betas = beta_start + diffusion_times * (beta_end - beta_start)\n",
    "    alpha_bars = tf.math.cumprod(1 - betas)\n",
    "    return tf.sqrt(1 - alpha_bars), tf.sqrt(alpha_bars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c08bb63-148f-4813-96c6-adec54d8956e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosine_diffusion_schedule(diffusion_times):\n",
    "    \"\"\"Cosine schedule: scale diffusion_times by pi / 2 and take sin / cos.\n",
    "\n",
    "    Returns:\n",
    "        (noise_rates, signal_rates): the sine and the cosine of the angle.\n",
    "    \"\"\"\n",
    "    angles = diffusion_times * math.pi / 2\n",
    "    return tf.sin(angles), tf.cos(angles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec7e1557-4ea5-4cd8-8f41-66a99ee3898b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def offset_cosine_diffusion_schedule(diffusion_times):\n",
    "    \"\"\"Cosine schedule whose signal rate is bounded between 0.02 and 0.95.\n",
    "\n",
    "    diffusion_times interpolates linearly between acos(0.95) and acos(0.02);\n",
    "    the rates are the sine and cosine of the resulting angle.\n",
    "\n",
    "    Returns:\n",
    "        (noise_rates, signal_rates)\n",
    "    \"\"\"\n",
    "    max_signal_rate, min_signal_rate = 0.95, 0.02\n",
    "    start_angle = tf.acos(max_signal_rate)\n",
    "    angle_span = tf.acos(min_signal_rate) - start_angle\n",
    "\n",
    "    angles = start_angle + diffusion_times * angle_span\n",
    "    return tf.sin(angles), tf.cos(angles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea5f3c3-4797-4b58-b559-45589f59aed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 1000\n",
    "# T evenly spaced diffusion times: 0, 1/T, ..., (T - 1)/T\n",
    "diffusion_times = tf.convert_to_tensor([step / T for step in range(T)])\n",
    "\n",
    "# Every schedule returns the pair (noise_rates, signal_rates)\n",
    "linear_schedule = linear_diffusion_schedule(diffusion_times)\n",
    "linear_noise_rates, linear_signal_rates = linear_schedule\n",
    "cosine_schedule = cosine_diffusion_schedule(diffusion_times)\n",
    "cosine_noise_rates, cosine_signal_rates = cosine_schedule\n",
    "offset_cosine_schedule = offset_cosine_diffusion_schedule(diffusion_times)\n",
    "offset_cosine_noise_rates, offset_cosine_signal_rates = offset_cosine_schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e2b7b5a-c79c-4728-8778-10656541dc0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the squared signal rate of every schedule against the diffusion time\n",
    "for schedule_label, schedule_signal_rates in [\n",
    "    (\"linear\", linear_signal_rates),\n",
    "    (\"cosine\", cosine_signal_rates),\n",
    "    (\"offset_cosine\", offset_cosine_signal_rates),\n",
    "]:\n",
    "    plt.plot(\n",
    "        diffusion_times,\n",
    "        schedule_signal_rates**2,\n",
    "        linewidth=1.5,\n",
    "        label=schedule_label,\n",
    "    )\n",
    "\n",
    "plt.xlabel(\"t/T\", fontsize=12)\n",
    "plt.ylabel(r\"$\\bar{\\alpha_t}$ (signal)\", fontsize=12)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92eac1df-3a74-4432-9d4d-0ccfbb5dc24e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the squared noise rate of every schedule against the diffusion time\n",
    "for schedule_label, schedule_noise_rates in [\n",
    "    (\"linear\", linear_noise_rates),\n",
    "    (\"cosine\", cosine_noise_rates),\n",
    "    (\"offset_cosine\", offset_cosine_noise_rates),\n",
    "]:\n",
    "    plt.plot(\n",
    "        diffusion_times,\n",
    "        schedule_noise_rates**2,\n",
    "        linewidth=1.5,\n",
    "        label=schedule_label,\n",
    "    )\n",
    "\n",
    "plt.xlabel(\"t/T\", fontsize=12)\n",
    "plt.ylabel(r\"$1-\\bar{\\alpha_t}$ (noise)\", fontsize=12)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6598d401-867d-450c-a3ba-6823bf309456",
   "metadata": {},
   "source": [
    "## 2. Build the model <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b305c3a-75c5-4b02-9130-79b661fc63cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sinusoidal_embedding(x):\n",
    "    \"\"\"Embed x with sines and cosines at NOISE_EMBEDDING_SIZE // 2 frequencies.\n",
    "\n",
    "    The frequencies are spaced evenly in log space between 1 and 1000. The\n",
    "    sine and cosine halves are concatenated on axis 3, so the broadcast\n",
    "    product of x with the frequencies must have at least four dimensions.\n",
    "    \"\"\"\n",
    "    num_frequencies = NOISE_EMBEDDING_SIZE // 2\n",
    "    log_frequencies = tf.linspace(\n",
    "        tf.math.log(1.0), tf.math.log(1000.0), num_frequencies\n",
    "    )\n",
    "    phases = 2.0 * math.pi * tf.exp(log_frequencies) * x\n",
    "    return tf.concat([tf.sin(phases), tf.cos(phases)], axis=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9711f6fb-6b43-4e79-9450-faef9f2afdad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the embedding at 100 noise variances: 0.00, 0.01, ..., 0.99\n",
    "embedding_list = []\n",
    "for y in np.arange(0, 1, 0.01):\n",
    "    # Wrap y as a 4-D array, then strip the three leading axes of the result\n",
    "    embedding_list.append(sinusoidal_embedding(np.array([[[[y]]]]))[0][0][0])\n",
    "# Transpose so rows are embedding dimensions and columns are noise variances\n",
    "embedding_array = np.array(np.transpose(embedding_list))\n",
    "fig, ax = plt.subplots()\n",
    "# Relabel column ticks 0, 10, ..., 90 with the variances 0.0, 0.1, ..., 0.9\n",
    "ax.set_xticks(\n",
    "    np.arange(0, 100, 10), labels=np.round(np.arange(0.0, 1.0, 0.1), 1)\n",
    ")\n",
    "ax.set_ylabel(\"embedding dimension\", fontsize=8)\n",
    "ax.set_xlabel(\"noise variance\", fontsize=8)\n",
    "# The colorbar is created right after pcolor, before imshow is drawn\n",
    "plt.pcolor(embedding_array, cmap=\"coolwarm\")\n",
    "plt.colorbar(orientation=\"horizontal\", label=\"embedding value\")\n",
    "# NOTE(review): imshow draws the same array again on the same axes without a\n",
    "# cmap argument; presumably kept for its effect on the axes layout - confirm\n",
    "ax.imshow(embedding_array, interpolation=\"nearest\", origin=\"lower\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ac9c2c-c5bc-4533-b78f-0c878a33893e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ResidualBlock(width):\n",
    "    \"\"\"Return a function that applies a residual block with `width` filters.\"\"\"\n",
    "\n",
    "    def apply(x):\n",
    "        # The skip path needs a 1x1 convolution only when the size of axis 3\n",
    "        # differs from the requested width.\n",
    "        if x.shape[3] != width:\n",
    "            residual = layers.Conv2D(width, kernel_size=1)(x)\n",
    "        else:\n",
    "            residual = x\n",
    "        normalized = layers.BatchNormalization(center=False, scale=False)(x)\n",
    "        hidden = layers.Conv2D(\n",
    "            width, kernel_size=3, padding=\"same\", activation=activations.swish\n",
    "        )(normalized)\n",
    "        output = layers.Conv2D(width, kernel_size=3, padding=\"same\")(hidden)\n",
    "        return layers.Add()([output, residual])\n",
    "\n",
    "    return apply\n",
    "\n",
    "\n",
    "def DownBlock(width, block_depth):\n",
    "    \"\"\"Return a function applying `block_depth` residual blocks, then 2x pooling.\n",
    "\n",
    "    The returned function takes a pair (tensor, skips). The output of every\n",
    "    residual block is appended to the `skips` list, which is mutated in place.\n",
    "    \"\"\"\n",
    "\n",
    "    def apply(x):\n",
    "        features, skips = x\n",
    "        for _ in range(block_depth):\n",
    "            features = ResidualBlock(width)(features)\n",
    "            skips.append(features)\n",
    "        return layers.AveragePooling2D(pool_size=2)(features)\n",
    "\n",
    "    return apply\n",
    "\n",
    "\n",
    "def UpBlock(width, block_depth):\n",
    "    \"\"\"Return a function applying 2x bilinear upsampling, then residual blocks.\n",
    "\n",
    "    The returned function takes a pair (tensor, skips). Before each of the\n",
    "    `block_depth` residual blocks, the last entry of `skips` is popped and\n",
    "    concatenated onto the tensor.\n",
    "    \"\"\"\n",
    "\n",
    "    def apply(x):\n",
    "        features, skips = x\n",
    "        features = layers.UpSampling2D(size=2, interpolation=\"bilinear\")(features)\n",
    "        for _ in range(block_depth):\n",
    "            features = layers.Concatenate()([features, skips.pop()])\n",
    "            features = ResidualBlock(width)(features)\n",
    "        return features\n",
    "\n",
    "    return apply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8455e5d-2cc7-4d72-94e1-412a0f5e69b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the U-Net\n",
    "\n",
    "# First input: the noisy image, projected to 32 channels by a 1x1 convolution\n",
    "noisy_images = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))\n",
    "x = layers.Conv2D(32, kernel_size=1)(noisy_images)\n",
    "\n",
    "# Second input: one noise variance per image, embedded sinusoidally and then\n",
    "# repeated over the IMAGE_SIZE x IMAGE_SIZE grid so it can be concatenated\n",
    "# with the image features\n",
    "noise_variances = layers.Input(shape=(1, 1, 1))\n",
    "noise_embedding = layers.Lambda(sinusoidal_embedding)(noise_variances)\n",
    "noise_embedding = layers.UpSampling2D(size=IMAGE_SIZE, interpolation=\"nearest\")(\n",
    "    noise_embedding\n",
    ")\n",
    "\n",
    "x = layers.Concatenate()([x, noise_embedding])\n",
    "\n",
    "# Shared list of skip connections: DownBlocks append to it, UpBlocks pop from it\n",
    "skips = []\n",
    "\n",
    "# Downsampling path\n",
    "x = DownBlock(32, block_depth=2)([x, skips])\n",
    "x = DownBlock(64, block_depth=2)([x, skips])\n",
    "x = DownBlock(96, block_depth=2)([x, skips])\n",
    "\n",
    "# Two residual blocks at the lowest resolution\n",
    "x = ResidualBlock(128)(x)\n",
    "x = ResidualBlock(128)(x)\n",
    "\n",
    "# Upsampling path, mirroring the widths of the downsampling path\n",
    "x = UpBlock(96, block_depth=2)([x, skips])\n",
    "x = UpBlock(64, block_depth=2)([x, skips])\n",
    "x = UpBlock(32, block_depth=2)([x, skips])\n",
    "\n",
    "# Output convolution back to 3 channels; its kernel is initialised to zeros\n",
    "x = layers.Conv2D(3, kernel_size=1, kernel_initializer=\"zeros\")(x)\n",
    "\n",
    "unet = models.Model([noisy_images, noise_variances], x, name=\"unet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f15745f6-c453-4342-af50-cd4fbf4556f3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class DiffusionModel(models.Model):\n",
    "    \"\"\"Diffusion model built around the U-Net and an EMA copy of it.\n",
    "\n",
    "    `network` is the trainable U-Net that predicts the noise in a noisy image.\n",
    "    `ema_network` is a clone whose weights track an exponential moving average\n",
    "    of the `network` weights; it is the network used whenever `denoise` is\n",
    "    called with training=False.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.normalizer = layers.Normalization()\n",
    "        self.network = unet\n",
    "        self.ema_network = models.clone_model(self.network)\n",
    "        self.diffusion_schedule = offset_cosine_diffusion_schedule\n",
    "\n",
    "    def compile(self, **kwargs):\n",
    "        \"\"\"Compile the model and create the tracker for the noise loss.\"\"\"\n",
    "        super().compile(**kwargs)\n",
    "        self.noise_loss_tracker = metrics.Mean(name=\"n_loss\")\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        \"\"\"Trackers whose results train_step and test_step report.\"\"\"\n",
    "        return [self.noise_loss_tracker]\n",
    "\n",
    "    def denormalize(self, images):\n",
    "        \"\"\"Undo the normalizer (mean and variance) and clip to [0, 1].\"\"\"\n",
    "        images = self.normalizer.mean + images * self.normalizer.variance**0.5\n",
    "        return tf.clip_by_value(images, 0.0, 1.0)\n",
    "\n",
    "    def denoise(self, noisy_images, noise_rates, signal_rates, training):\n",
    "        \"\"\"Split noisy images into a predicted noise and a predicted image.\n",
    "\n",
    "        Uses `network` when training is truthy and `ema_network` otherwise.\n",
    "        The network is fed the noise variance (noise_rates**2).\n",
    "\n",
    "        Returns:\n",
    "            (pred_noises, pred_images)\n",
    "        \"\"\"\n",
    "        network = self.network if training else self.ema_network\n",
    "        pred_noises = network(\n",
    "            [noisy_images, noise_rates**2], training=training\n",
    "        )\n",
    "        # Invert noisy = signal_rate * image + noise_rate * noise for the image\n",
    "        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates\n",
    "\n",
    "        return pred_noises, pred_images\n",
    "\n",
    "    def reverse_diffusion(self, initial_noise, diffusion_steps):\n",
    "        \"\"\"Denoise `initial_noise` step by step over `diffusion_steps` steps.\n",
    "\n",
    "        The diffusion time starts at 1 and decreases by 1 / diffusion_steps\n",
    "        per step. Returns the image prediction made in the last step.\n",
    "        \"\"\"\n",
    "        num_images = initial_noise.shape[0]\n",
    "        step_size = 1.0 / diffusion_steps\n",
    "        current_images = initial_noise\n",
    "        for step in range(diffusion_steps):\n",
    "            diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size\n",
    "            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)\n",
    "            pred_noises, pred_images = self.denoise(\n",
    "                current_images, noise_rates, signal_rates, training=False\n",
    "            )\n",
    "            # Re-mix the predicted image and noise using the rates of the next,\n",
    "            # smaller diffusion time\n",
    "            next_diffusion_times = diffusion_times - step_size\n",
    "            next_noise_rates, next_signal_rates = self.diffusion_schedule(\n",
    "                next_diffusion_times\n",
    "            )\n",
    "            current_images = (\n",
    "                next_signal_rates * pred_images + next_noise_rates * pred_noises\n",
    "            )\n",
    "        return pred_images\n",
    "\n",
    "    def generate(self, num_images, diffusion_steps, initial_noise=None):\n",
    "        \"\"\"Generate images by reverse diffusion and denormalize them.\n",
    "\n",
    "        `num_images` is only used to sample noise when `initial_noise` is None;\n",
    "        when `initial_noise` is given, its first dimension sets the count.\n",
    "        \"\"\"\n",
    "        if initial_noise is None:\n",
    "            initial_noise = tf.random.normal(\n",
    "                shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, 3)\n",
    "            )\n",
    "        generated_images = self.reverse_diffusion(\n",
    "            initial_noise, diffusion_steps\n",
    "        )\n",
    "        generated_images = self.denormalize(generated_images)\n",
    "        return generated_images\n",
    "\n",
    "    def _diffuse(self, images):\n",
    "        \"\"\"Mix normalized images with Gaussian noise at random diffusion times.\n",
    "\n",
    "        The batch size is read from `images` instead of the global BATCH_SIZE,\n",
    "        so a batch of any size is handled.\n",
    "\n",
    "        Returns:\n",
    "            (noises, noisy_images, noise_rates, signal_rates)\n",
    "        \"\"\"\n",
    "        batch_size = tf.shape(images)[0]\n",
    "        noises = tf.random.normal(shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))\n",
    "        diffusion_times = tf.random.uniform(\n",
    "            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0\n",
    "        )\n",
    "        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)\n",
    "        noisy_images = signal_rates * images + noise_rates * noises\n",
    "        return noises, noisy_images, noise_rates, signal_rates\n",
    "\n",
    "    def train_step(self, images):\n",
    "        \"\"\"Run one optimisation step on a batch and update the EMA weights.\"\"\"\n",
    "        images = self.normalizer(images, training=True)\n",
    "        noises, noisy_images, noise_rates, signal_rates = self._diffuse(images)\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            # train the network to separate noisy images to their components\n",
    "            pred_noises, pred_images = self.denoise(\n",
    "                noisy_images, noise_rates, signal_rates, training=True\n",
    "            )\n",
    "\n",
    "            noise_loss = self.loss(noises, pred_noises)  # used for training\n",
    "\n",
    "        gradients = tape.gradient(noise_loss, self.network.trainable_weights)\n",
    "        self.optimizer.apply_gradients(\n",
    "            zip(gradients, self.network.trainable_weights)\n",
    "        )\n",
    "\n",
    "        self.noise_loss_tracker.update_state(noise_loss)\n",
    "\n",
    "        # Move every EMA weight a fraction (1 - EMA) towards the trained weight\n",
    "        for weight, ema_weight in zip(\n",
    "            self.network.weights, self.ema_network.weights\n",
    "        ):\n",
    "            ema_weight.assign(EMA * ema_weight + (1 - EMA) * weight)\n",
    "\n",
    "        return {m.name: m.result() for m in self.metrics}\n",
    "\n",
    "    def test_step(self, images):\n",
    "        \"\"\"Compute the noise loss on a batch with the EMA network; no update.\"\"\"\n",
    "        images = self.normalizer(images, training=False)\n",
    "        noises, noisy_images, noise_rates, signal_rates = self._diffuse(images)\n",
    "        pred_noises, pred_images = self.denoise(\n",
    "            noisy_images, noise_rates, signal_rates, training=False\n",
    "        )\n",
    "        noise_loss = self.loss(noises, pred_noises)\n",
    "        self.noise_loss_tracker.update_state(noise_loss)\n",
    "\n",
    "        return {m.name: m.result() for m in self.metrics}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe9df93-c04f-47a9-8686-747e1aba1cfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "ddm = DiffusionModel()\n",
    "# Fit the normalization layer's mean and variance to the training images\n",
    "ddm.normalizer.adapt(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07be3d70-53be-492b-af23-2432d73405e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "if LOAD_MODEL:\n",
    "    # NOTE(review): presumably marks the model as built so that weights can be\n",
    "    # loaded before the model has ever been called - confirm\n",
    "    ddm.built = True\n",
    "    ddm.load_weights(\"./checkpoint/checkpoint.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f392424-45a9-49cc-8ea0-c1bec9064d74",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 3. Train the model <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8da961bb-f8ed-4eb3-ae6d-2af32a3c8401",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AdamW optimizer; mean absolute error is the noise loss used by train_step\n",
    "ddm.compile(\n",
    "    optimizer=optimizers.experimental.AdamW(\n",
    "        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
    "    ),\n",
    "    loss=losses.mean_absolute_error,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96e012b8-1fe7-40e9-9fb1-3fb2c3cbcb23",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# run training and plot generated images periodically\n",
    "# Save only the model weights to the checkpoint path at the end of every epoch\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint/checkpoint.ckpt\",\n",
    "    save_weights_only=True,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# Write TensorBoard logs to ./logs\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    \"\"\"Callback that generates `num_img` images at the end of every epoch.\n",
    "\n",
    "    The images are passed to `display`, which is asked to save them under\n",
    "    ./output/ with the zero-padded epoch number in the file name.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_img):\n",
    "        # Initialise the Keras Callback base class before adding our own state\n",
    "        super().__init__()\n",
    "        self.num_img = num_img\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        generated_images = self.model.generate(\n",
    "            num_images=self.num_img,\n",
    "            diffusion_steps=PLOT_DIFFUSION_STEPS,\n",
    "        ).numpy()\n",
    "        display(\n",
    "            generated_images,\n",
    "            save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "        )\n",
    "\n",
    "\n",
    "# Generate 10 images at the end of each epoch\n",
    "image_generator_callback = ImageGenerator(num_img=10)\n",
    "\n",
    "# Train for EPOCHS epochs with checkpointing, TensorBoard logging and the\n",
    "# per-epoch image generation\n",
    "ddm.fit(\n",
    "    train,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[\n",
    "        model_checkpoint_callback,\n",
    "        tensorboard_callback,\n",
    "        image_generator_callback,\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "180fb0a1-ed16-47c2-b326-ad66071cd6e2",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 4. Inference <a name=\"inference\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84db7dcc-313d-41b6-9357-c403ea97add3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate some novel images of flowers\n",
    "# 10 images from random noise with 20 reverse diffusion steps, converted to\n",
    "# a NumPy array before being displayed\n",
    "generated_images = ddm.generate(num_images=10, diffusion_steps=20).numpy()\n",
    "display(generated_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe0d6859-652e-421a-974b-75ce3c919244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# View improvement over greater number of diffusion steps\n",
    "for diffusion_steps in [*np.arange(1, 6, 1), 20, 100]:\n",
    "    # Reset the global TensorFlow seed before every run\n",
    "    tf.random.set_seed(42)\n",
    "    images_tensor = ddm.generate(\n",
    "        num_images=10, diffusion_steps=diffusion_steps\n",
    "    )\n",
    "    generated_images = images_tensor.numpy()\n",
    "    display(generated_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2254193-c1ab-43fa-97f9-976723898abc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Interpolation between two points in the latent space\n",
    "tf.random.set_seed(100)\n",
    "\n",
    "\n",
    "def spherical_interpolation(a, b, t):\n",
    "    \"\"\"Blend a and b with weights sin(t * pi / 2) and cos(t * pi / 2).\n",
    "\n",
    "    At t = 0 the weight on a is 0 and the weight on b is 1, so the result is\n",
    "    b; as t approaches 1 the result approaches a.\n",
    "    \"\"\"\n",
    "    return np.sin(t * math.pi / 2) * a + np.cos(t * math.pi / 2) * b\n",
    "\n",
    "\n",
    "for i in range(5):\n",
    "    a = tf.random.normal(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))\n",
    "    b = tf.random.normal(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))\n",
    "    # One interpolated noise tensor per value of t\n",
    "    initial_noise = np.array(\n",
    "        [spherical_interpolation(a, b, t) for t in np.arange(0, 1.1, 0.1)]\n",
    "    )\n",
    "    # generate() ignores num_images when initial_noise is supplied, so pass\n",
    "    # the real number of noise tensors rather than an unrelated constant\n",
    "    generated_images = ddm.generate(\n",
    "        num_images=len(initial_noise),\n",
    "        diffusion_steps=20,\n",
    "        initial_noise=initial_noise,\n",
    "    ).numpy()\n",
    "    display(generated_images, n=11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "704bce2f-c567-4382-a959-f5dfe7b538f7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
